//    OpenVPN -- An application to securely tunnel IP networks
//               over a single port, with support for SSL/TLS-based
//               session authentication and key exchange,
//               packet encryption, packet authentication, and
//               packet compression.
//
//    Copyright (C) 2012- OpenVPN Inc.
//
//    SPDX-License-Identifier: MPL-2.0 OR AGPL-3.0-only WITH openvpn3-openssl-exception
//

#pragma once

#include <string>
#include <cstdint>
#include <cstring>
#include <ostream>
#include <tuple>
#include <utility>

#include <openvpn/common/exception.hpp>
#include <openvpn/common/rc.hpp>
#include <openvpn/common/base64.hpp>
#include <openvpn/common/socktypes.hpp>
#include <openvpn/common/endian64.hpp>
#include <openvpn/crypto/hashstr.hpp>
#include <openvpn/buffer/buffer.hpp>
#include <openvpn/random/randapi.hpp>

namespace openvpn::WebSocket {

// Exception type thrown by this file for WebSocket framing/protocol violations,
// internal consistency failures and missing configuration.
OPENVPN_EXCEPTION(websocket_error);

// Forward declaration; Status (below) names Receiver as a friend so the
// parser can fill in its private fields directly.
class Receiver;

// Compute the Sec-WebSocket-Accept value that answers a client's
// Sec-WebSocket-Key: base64(SHA-1(key + GUID)), where the GUID is the fixed
// string defined by RFC 6455 for the opening handshake.
inline std::string accept_confirmation(DigestFactory &digest_factory,
                                       const std::string &websocket_key)
{
    static constexpr char guid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    HashString sha1(digest_factory, CryptoAlgs::SHA1);
    std::string challenge(websocket_key);
    challenge.append(guid);
    sha1.update(challenge);
    return sha1.final_base64();
}

// WebSocket wire-format constants and helpers shared by Sender and Receiver.
class Protocol
{
  public:
    // NOTE(review): presumably the buffer headroom callers reserve for
    // Sender::frame(), which prepends at most 16 bytes:
    // 2 (Close status code) + 1 (FIN/opcode) + 1 (length) + 8 (extended length)
    // + 4 (masking key).
    static constexpr size_t MAX_HEAD = 16;

    // Frame opcodes (RFC 6455 section 5.2).
    enum Opcode
    {
        Text = 0x1,
        Binary = 0x2,
        Close = 0x8,
        Ping = 0x9,
        Pong = 0xA,
    };

    // Human-readable opcode name for logging; unknown values are rendered as
    // "WS-OPCODE-<n>".
    static std::string opcode_to_string(const unsigned int opcode)
    {
        switch (opcode)
        {
        case Text:
            return "Text";
        case Binary:
            return "Binary";
        case Close:
            return "Close";
        case Ping:
            return "Ping";
        case Pong:
            return "Pong";
        default:
            return "WS-OPCODE-" + std::to_string(opcode);
        }
    }

    // 4-byte frame masking key.
    //
    // The key bytes are kept in the memory order of the uint32_t the key was
    // built from. Receiver reads the key raw off the wire into a uint32_t and
    // Sender draws one from an RNG, so in both cases the bytes used for XOR
    // and the bytes written by prepend_mask() match the wire order.
    class MaskingKey
    {
      public:
        MaskingKey(std::uint32_t mask)
        {
            // Copy the object representation with memcpy. Reading it back
            // through an inactive union member would be undefined behavior in C++.
            std::memcpy(key, &mask, sizeof(key));
        }

        // XOR every byte of buf with the key, cycling through the 4 key bytes.
        // Applying it twice restores the original data, so this both masks
        // and unmasks.
        void xor_buf(Buffer &buf) const
        {
            const size_t size = buf.size();
            std::uint8_t *data = buf.data();
            for (size_t i = 0; i < size; ++i)
                data[i] ^= key[i & 0x3];
        }

        // Prepend the 4 key bytes to buf, in wire order.
        void prepend_mask(Buffer &buf) const
        {
            buf.prepend(key, sizeof(key));
        }

      private:
        std::uint8_t key[4];
    };
};

// Metadata of a single WebSocket frame: opcode, FIN flag and, for Close
// frames, the close status code.
class Status
{
  public:
    // Undefined status: opcode 0, so defined() returns false.
    Status()
        : opcode_(0),
          fin_(false),
          close_status_code_(0)
    {
    }

    // opcode: a Protocol::Opcode value.
    // fin: the final-fragment flag.
    // close_status_code: only framed by Sender / parsed by Receiver when the
    // opcode is Protocol::Close.
    Status(unsigned int opcode,
           bool fin = true,
           uint16_t close_status_code = 0)
        : opcode_(opcode),
          fin_(fin),
          close_status_code_(close_status_code)
    {
    }

    // Copy ref's FIN flag and close status code, but use the given opcode.
    Status(const Status &ref,
           const unsigned int opcode)
        : opcode_(opcode),
          fin_(ref.fin_),
          close_status_code_(ref.close_status_code_)
    {
    }

    // True when the opcode is nonzero (a default-constructed Status has opcode 0).
    bool defined() const
    {
        return opcode_ != 0;
    }

    unsigned int opcode() const
    {
        return opcode_;
    }

    bool fin() const
    {
        return fin_;
    }

    auto close_status_code() const
    {
        return close_status_code_;
    }

    // Field-wise equality over opcode, FIN flag and close status code.
    bool operator==(const Status &rhs) const
    {
        return std::tie(opcode_, fin_, close_status_code_) == std::tie(rhs.opcode_, rhs.fin_, rhs.close_status_code_);
    }

    bool operator!=(const Status &rhs) const
    {
        return !(*this == rhs);
    }

    // Debug representation, e.g. "[op=Text fin=1]". The close status is only
    // included for Close frames.
    std::string to_string() const
    {
        std::string ret;

        ret.reserve(64);
        ret += "[op=";
        ret += Protocol::opcode_to_string(opcode_);
        ret += " fin=";
        ret += std::to_string(fin_);
        if (opcode_ == Protocol::Close)
        {
            ret += " status=";
            ret += std::to_string(close_status_code_);
        }
        ret += ']';
        return ret;
    }

  private:
    // Receiver fills these fields in directly while parsing a frame header.
    friend class Receiver;

    unsigned int opcode_;
    bool fin_;
    uint16_t close_status_code_;
};

class Sender
{
  public:
    Sender(StrongRandomAPI::Ptr cli_rng_arg) // only provide rng on client side
        : cli_rng(std::move(cli_rng_arg))
    {
    }

    void frame(Buffer &buf, const Status &s) const
    {
        if (s.opcode() == Protocol::Close)
        {
            const std::uint16_t cs = htons(s.close_status_code());
            buf.prepend(&cs, sizeof(cs));
        }

        const size_t payload_len = buf.size();
        if (cli_rng)
        {
            const Protocol::MaskingKey mk(cli_rng->rand_get<std::uint32_t>());
            mk.xor_buf(buf);
            mk.prepend_mask(buf);
        }
        prepend_payload_length(buf, payload_len);

        std::uint8_t head = s.opcode() & 0xF;
        if (s.fin())
            head |= 0x80;
        buf.prepend(&head, sizeof(head));

        // OPENVPN_LOG("WS SEND HEAD\n" << dump_hex(buf));
    }

  private:
    void prepend_payload_length(Buffer &buf, const size_t len) const
    {
        std::uint8_t len8;

        if (len <= 125)
            len8 = static_cast<std::uint8_t>(len);
        else if (len <= 65535)
        {
            len8 = 126;
            const std::uint16_t len16 = htons(static_cast<std::uint16_t>(len));
            buf.prepend(&len16, sizeof(len16));
        }
        else
        {
            len8 = 127;
            const std::uint64_t len64 = Endian::rev64(len);
            buf.prepend(&len64, sizeof(len64));
        }

        if (cli_rng)
            len8 |= 0x80;
        buf.prepend(&len8, sizeof(len8));
    }

    StrongRandomAPI::Ptr cli_rng;
};

// Incremental parser for inbound WebSocket frames.
//
// The public methods imply this cycle:
//   1. Append received bytes with add_buf().
//   2. Call complete() until it returns true.
//   3. Read status() and buf_unframed().
//   4. Call reset() to discard the consumed frame while keeping any bytes
//      that already belong to the next frame.
class Receiver
{
  public:
    // is_client_arg: true on the client side, where incoming frames must be
    // unmasked; false on the server side, where incoming frames must be masked.
    Receiver(const bool is_client_arg)
        : is_client(is_client_arg)
    {
        reset_pod();
    }

    // Return a Buffer over the completed message's payload. It points into
    // the receive buffer; no copy is made.
    // For Close frames the status code has already been stripped off.
    // Throws websocket_error if the message is not complete yet.
    // NOTE(review): the returned view presumably dangles after add_buf()/reset().
    Buffer buf_unframed()
    {
        verify_message_complete();
        if (size > buf.size())
            throw websocket_error("Receiver::buf_unframed: internal error");
        return Buffer(buf.data(), static_cast<size_t>(size), true);
    }

    // Return true once a whole frame (header and payload) is buffered, and
    // false if more bytes are needed. The header is parsed from a scratch view
    // and only consumed from buf once all of it is present.
    // Throws websocket_error on protocol violations.
    // NOTE(review): RFC 6455 control-frame limits (payload <= 125 bytes, FIN
    // set) are not enforced here.
    bool complete()
    {
        // header already parsed: only the payload may still be missing
        if (header_complete)
            return complete_();

        // the smallest possible header is 2 bytes
        if (buf.size() < 2)
            return false;

        // fixed 2-byte part of the header
        Buffer b(buf.data(), buf.size(), true);
        const std::uint8_t *head = b.read_alloc(2);
        s.opcode_ = head[0] & 0xF;
        s.fin_ = bool(head[0] & 0x80);
        if (head[0] & 0x70)
            throw websocket_error("Receiver: reserved bits are set");
        // clients must receive unmasked frames, servers masked ones
        if (bool(head[1] & 0x80) == is_client)
            throw websocket_error("Receiver: bad masking direction");

        // Payload length: a 7-bit value, or 126/127 followed by a 16/64-bit
        // network-order extended length.
        const std::uint8_t pl = head[1] & 0x7f;
        if (pl <= 125)
        {
            size = pl;
        }
        else if (pl == 126)
        {
            std::uint16_t len16;
            if (b.size() < sizeof(len16))
                return false;
            b.read(&len16, sizeof(len16));
            size = ntohs(len16);
        }
        else // pl == 127
        {
            std::uint64_t len64;
            if (b.size() < sizeof(len64))
                return false;
            b.read(&len64, sizeof(len64));
            size = Endian::rev64(len64);
            // RFC 6455 section 5.2: the most significant bit of a 64-bit length
            // MUST be 0. Without this check such a frame would stall the
            // receiver forever waiting for an impossible payload.
            if (size & (std::uint64_t(1) << 63))
                throw websocket_error("Receiver: invalid 64-bit payload length");
        }

        // read mask (server side only)
        if (!is_client)
        {
            if (b.size() < sizeof(mask))
                return false;
            b.read(&mask, sizeof(mask));
        }

        // commit: drop the header bytes from the receive buffer
        buf.advance(b.offset());
        header_complete = true;
        return complete_();
    }

    // Append received bytes. The first buffer is adopted directly and marked
    // growable; later buffers are appended to it.
    void add_buf(BufferAllocated &&inbuf)
    {
        if (!buf.allocated())
        {
            buf = std::move(inbuf);
            buf.or_flags(BufAllocFlags::GROW);
        }
        else
            buf.append(inbuf);
    }

    // Discard the completed message and prepare to parse the next frame.
    // Throws websocket_error if the current message is not complete.
    void reset()
    {
        verify_message_complete();
        s = Status();
        reset_buf();
        reset_pod();
    }

    // Status of the completed message. Throws websocket_error if the message
    // is incomplete.
    Status status() const
    {
        verify_message_complete();
        return s;
    }

  private:
    // Drop the consumed payload from buf, keeping any following bytes (the
    // start of the next frame).
    void reset_buf()
    {
        if (buf.allocated())
        {
            if (size < buf.size())
            {
                buf.advance(static_cast<size_t>(size));
                buf.realign(0);
            }
            else if (size == buf.size())
                buf.clear();
            else
                throw websocket_error("Receiver::reset_buf: bad size");
        }
    }

    // Reset the per-frame parse state.
    void reset_pod()
    {
        header_complete = false;
        message_complete = false;
        mask = 0;
        size = 0;
    }

    void verify_message_complete() const
    {
        if (!message_complete)
            throw websocket_error("Receiver: message incomplete");
    }

    // Called once the header is parsed. When the whole payload is buffered:
    // unmask it (server side), strip a Close status code, and mark the
    // message complete.
    bool complete_()
    {
        if (message_complete)
            return true;

        if (header_complete && size <= buf.size())
        {
            // un-xor the data on the server side only
            if (!is_client)
            {
                Buffer b(buf.data(), static_cast<size_t>(size), true);
                const Protocol::MaskingKey mk(mask);
                mk.xor_buf(b);
            }

            // Extract the close status code. Reading it advances buf and
            // shrinks size, so the payload view excludes it.
            if (s.opcode_ == Protocol::Close && size >= 2)
            {
                std::uint16_t cs;
                buf.read(&cs, sizeof(cs));
                size -= sizeof(cs);
                s.close_status_code_ = ntohs(cs);
            }

            message_complete = true;
            return true;
        }
        return false;
    }

    const bool is_client;
    bool header_complete;
    bool message_complete;
    std::uint32_t mask;  // masking key, raw wire bytes (server side only)
    std::uint64_t size;  // payload bytes remaining in the current message
    Status s;
    BufferAllocated buf; // received bytes not yet consumed
};

namespace Client {

// Client-side WebSocket settings, shared by reference count across requests.
struct Config : public RC<thread_unsafe_refcount>
{
    using Ptr = RCPtr<Config>;

    // Value for the optional "Origin:" request header; omitted when empty.
    std::string origin;
    // Value for the optional "Sec-WebSocket-Protocol:" request header; omitted when empty.
    std::string protocol;
    // RNG for the Sec-WebSocket-Key and for masking outgoing frames.
    StrongRandomAPI::Ptr rng;
    // Digest factory used to verify Sec-WebSocket-Accept; required.
    DigestFactory::Ptr digest_factory;

    // Compression settings.
    // NOTE(review): not read anywhere in this file; confirm the consumer.
    bool compress{false};
    size_t compress_threshold{256};
};

// Per-connection state for a WebSocket client. It:
//   - generates the upgrade request headers,
//   - verifies the server's Sec-WebSocket-Accept reply,
//   - owns the frame Sender/Receiver for the connection.
class PerRequest : public RC<thread_unsafe_refcount>
{
  private:
    // Must stay declared before `sender`: the constructor initializes sender
    // from conf->rng, and members are initialized in declaration order.
    Config::Ptr conf;

  public:
    using Ptr = RCPtr<PerRequest>;

    // Throws websocket_error if conf_arg is null or lacks a digest factory or RNG.
    PerRequest(Config::Ptr conf_arg)
        : conf(validate_conf(std::move(conf_arg))),
          sender(conf->rng),
          receiver(true)
    {
    }

    // Write the client's upgrade request headers to os. Each call generates a
    // fresh Sec-WebSocket-Key, which confirm_websocket_key() later checks.
    void client_headers(std::ostream &os)
    {
        generate_websocket_key();
        os << "Sec-WebSocket-Key: " << websocket_key << "\r\n";
        os << "Sec-WebSocket-Version: 13\r\n";
        if (!conf->protocol.empty())
            os << "Sec-WebSocket-Protocol: " << conf->protocol << "\r\n";
        os << "Connection: Upgrade\r\n";
        os << "Upgrade: websocket\r\n";
        if (!conf->origin.empty())
            os << "Origin: " << conf->origin << "\r\n";
    }

    // True if the server's Sec-WebSocket-Accept value matches the key sent by
    // the last client_headers() call.
    bool confirm_websocket_key(const std::string &ws_accept) const
    {
        return ws_accept == accept_confirmation(*conf->digest_factory, websocket_key);
    }

    Sender sender;
    Receiver receiver;

  private:
    static Config::Ptr validate_conf(Config::Ptr conf)
    {
        if (!conf)
            throw websocket_error("no config");
        if (!conf->digest_factory)
            throw websocket_error("no digest factory in config");
        // A client needs an RNG for its Sec-WebSocket-Key, and RFC 6455
        // requires clients to mask every frame they send. Without an RNG,
        // generate_websocket_key() would dereference a null pointer and
        // Sender would emit unmasked frames.
        if (!conf->rng)
            throw websocket_error("no RNG in config");
        return conf;
    }

    // Fill websocket_key with base64 of 16 random bytes.
    void generate_websocket_key()
    {
        std::uint8_t data[16];
        conf->rng->rand_bytes(data, sizeof(data));
        websocket_key = base64->encode(data, sizeof(data));
    }

    std::string websocket_key;
};

} // namespace Client

namespace Server {

// Server-side WebSocket settings, shared by reference count across requests.
struct Config : public RC<thread_unsafe_refcount>
{
    using Ptr = RCPtr<Config>;

    // Value for the optional "Sec-WebSocket-Protocol:" response header; omitted when empty.
    std::string protocol;
    // Digest factory used to compute Sec-WebSocket-Accept; required.
    DigestFactory::Ptr digest_factory;
};

// Per-connection state for a WebSocket server. It builds the upgrade response
// headers and owns the frame Sender/Receiver for the connection.
class PerRequest : public RC<thread_unsafe_refcount>
{
  private:
    Config::Ptr conf;

  public:
    using Ptr = RCPtr<PerRequest>;

    // Throws websocket_error when conf_arg is null or has no digest factory.
    PerRequest(Config::Ptr conf_arg)
        : conf(validate_conf(std::move(conf_arg))),
          sender(StrongRandomAPI::Ptr{}), // null RNG: outgoing frames are not masked
          receiver(false)                 // server side: incoming frames must be masked
    {
    }

    // Derive and remember the Sec-WebSocket-Accept value for the client's key.
    void set_websocket_key(const std::string &websocket_key)
    {
        websocket_accept = accept_confirmation(*conf->digest_factory, websocket_key);
    }

    // Write the upgrade response headers to os.
    // Sec-WebSocket-Accept is written only when set_websocket_key() has
    // produced a value; Sec-WebSocket-Protocol only when configured.
    void server_headers(std::ostream &os)
    {
        os << "Upgrade: websocket\r\n"
           << "Connection: Upgrade\r\n";

        const bool have_accept = !websocket_accept.empty();
        if (have_accept)
            os << "Sec-WebSocket-Accept: " << websocket_accept << "\r\n";

        const std::string &subprotocol = conf->protocol;
        if (!subprotocol.empty())
            os << "Sec-WebSocket-Protocol: " << subprotocol << "\r\n";
    }

    Sender sender;
    Receiver receiver;

  private:
    static Config::Ptr validate_conf(Config::Ptr candidate)
    {
        if (!candidate)
            throw websocket_error("no config");
        if (!candidate->digest_factory)
            throw websocket_error("no digest factory in config");
        return candidate;
    }

    std::string websocket_accept;
};

} // namespace Server

} // namespace openvpn::WebSocket
